[DSv4][Nvidia] SM12x DeepSeek V4 support#40991

Closed
jasl wants to merge 11 commits into vllm-project:main from jasl:ds4-sm120

Conversation

@jasl
Contributor

@jasl jasl commented Apr 27, 2026

This PR incorporates #40929 and is now DeepGEMM-free, thanks to @bbbearxyz!

UPDATE: To better align with the DeepSeek official API and the B200 code path, I made a harness to help measure correctness, performance, and quality: https://github.com/jasl/vllm-ds4-sm120-harness
I will post the latest report there for people to review.

Summary

This PR enables serving DeepSeek V4 Flash on NVIDIA SM12x GPUs, tested on a
host with 2x RTX PRO 6000 Blackwell Workstation Edition GPUs.

The important change from the earlier prototype is that this PR no longer pins
or rewrites the DeepGEMM dependency. The branch keeps vLLM's upstream DeepGEMM
installer and CMake metadata intact, and implements the required SM12x runtime
fallbacks in vLLM:

  • DeepSeek V4 tokenizer / parser / model integration.
  • Portable Triton sparse MLA path for SM12x.
  • fp8_ds_mla sparse MLA cache handling.
  • Sink-aware SWA + compressed sparse attention.
  • vLLM-side SM12x fallbacks for DeepSeek V4-specific DeepGEMM calls.
  • SM12x sparse indexer and paged MQA fallback kernels.
  • Guardrails so existing SM90 / SM100 optimized paths remain unchanged (a
    rough dispatch sketch follows this list).
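
The guardrail idea is a plain capability gate: explicit overrides win, otherwise only SM12x falls through to the portable path. A minimal sketch, assuming the usual torch capability query (the helper name below is illustrative, not this PR's code; the actual gating lives in helpers such as is_triton_sparse_mla_enabled_for_platform()):

```python
import torch


def _use_triton_sparse_mla(force: int | None) -> bool:
    """Decide whether to take the portable Triton sparse MLA path.

    force: parsed value of VLLM_TRITON_MLA_SPARSE (None means auto).
    """
    if force is not None:
        return bool(force)  # explicit 1 / 0 always wins
    major, _minor = torch.cuda.get_device_capability()
    # SM90 / SM100 keep their optimized FlashMLA / tcgen05 paths untouched;
    # only SM12x (compute capability 12.x) uses the Triton fallback.
    return major == 12
```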

Motivation

DeepSeek V4 currently relies on kernels that are available on Hopper and
datacenter Blackwell paths, but not on SM120 / SM121 workstation and consumer
Blackwell GPUs. In particular, SM12x cannot directly reuse SM90 WGMMA kernels
or SM100 tcgen05 kernels.

This PR adds correctness-first portable kernels for the missing SM12x pieces,
then optimizes the hot sparse MLA paths enough for real serving. The result is
a reviewable vLLM-side compatibility layer that does not require maintainers to
accept a temporary DeepGEMM fork pin.

Scope

Included:

  • SM12x Triton sparse MLA decode and prefill paths.
  • fp8_ds_mla packed cache decode for SWA and compressed sparse candidates.
  • Sink-aware sparse attention denominator semantics.
  • SM12x local fallbacks for DeepSeek V4-specific DeepGEMM call sites.
  • Sparse indexer memory bound fixes for long prefill.
  • DeepSeek V4 tokenizer handling and tool-call parser fixes needed by the new
    model path.
  • Targeted correctness tests and an HTTP logprobs oracle comparator.

Not included:

  • Replacing FlashMLA on SM90 / SM100.
  • A final Tensor Core implementation for every SM12x kernel.
  • MTP speculative decoding fixes. Those are kept in a separate branch / PR.
  • Community performance experiments that are useful for evaluation but too
    broad for this PR.
  • Any DeepGEMM fork pin or DeepGEMM CMake / install-script rewrite.

Runtime controls

The SM12x sparse MLA path registers its environment variables in vllm.envs,
so users should not see unknown-variable warnings for these knobs. A rough
sketch of how the auto/1/0 knobs are parsed follows the table.

| Variable | Default | Meaning |
| --- | --- | --- |
| VLLM_TRITON_MLA_SPARSE | auto | 1 forces the Triton sparse MLA path, 0 disables it. When unset, vLLM enables it on SM12x where FlashMLA sparse is unavailable. |
| VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE | 512 | Top-k candidate chunk size for sparse MLA accumulation. Lower values reduce transient workspace at the cost of more kernel work. |
| VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE | 256 | Query chunk size used by the prefill sparse MLA fallback. |
| VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE | auto | Optional decode head block override. Supported values are 1, 2, and 4; benchmarks used 4. |
| VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE | auto | Optional matmul-based sparse MLA decode toggle. When unset, it auto-enables on SM12x. |
| VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH | context dependent | Allows compile / CUDA graphs for the sparse MLA path. In the formal PR branch, unset keeps graphs for normal decode and disables them for speculative decoding; 1 forces allow, 0 disables. |
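
The auto/1/0 convention above maps onto a simple tri-state parse: unset means "let vLLM decide", while 1 and 0 force the choice. A hypothetical illustration (the helper below is not the PR's code; vllm.envs has its own registration mechanism):

```python
import os


def _tri_state(name: str) -> bool | None:
    """Parse an auto/1/0 knob: None means auto, True/False force a choice."""
    raw = os.environ.get(name)
    if raw is None or raw == "":
        return None
    return raw not in ("0", "false", "False")


# Chunk-size knobs fall back to the documented defaults when unset.
topk_chunk = int(os.environ.get("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512"))
query_chunk = int(os.environ.get("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256"))
```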

Operational warning: do not set
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True with the TP=2 CUDA graph
configuration used below. In local testing it made custom all-reduce fail
during CUDA graph address registration. Leaving it unset avoids that failure.

Branches

Formal PR branch:

jasl/vllm@ds4-sm120
HEAD: 7a34ed538

Preview / evaluation branch with extra community performance work and MTP fixes:

jasl/vllm@ds4-sm120-full
HEAD: ab7336f21

The preview branch is not intended as the review target. It exists so users can
try the broader optimization stack while this PR stays focused.

Test environment

Hardware:

Host: jasl-workstation
GPU: 2x NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Compute capability: SM120
GPU memory: 95 GiB class per GPU

Software:

OS: Ubuntu, Linux 7.0.0-14-generic
CUDA toolkit: /usr/local/cuda
Python: 3.13.13
PyTorch: 2.11.0+cu130
vLLM package metadata: 0.20.1rc1.dev12+g363ffa145

Benchmark environment:

export PATH="/usr/local/cuda/bin:$PATH"
export CUDA_HOME="/usr/local/cuda"
export TRITON_PTXAS_PATH="/usr/local/cuda/bin/ptxas"
export CUDA_ARCH_LIST="120a"
export TORCH_CUDA_ARCH_LIST="12.0a"
export VLLM_TRITON_MLA_SPARSE=1
export VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=4
export VLLM_RPC_TIMEOUT=100000
unset PYTORCH_CUDA_ALLOC_CONF

Note: DGX Spark uses 121a and 12.1a.

Validation

Formal PR branch checks:

python -m ruff check \
  vllm/envs.py \
  vllm/utils/deep_gemm.py \
  vllm/tokenizers/deepseek_v4_encoding.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  vllm/v1/attention/backends/mla/sparse_mla_env.py \
  vllm/v1/attention/backends/mla/sparse_swa.py \
  tests/tokenizers_/test_deepseek_v4.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py \
  tests/v1/attention/test_sm120_deepgemm_fallbacks.py

Result:

All checks passed!

Compile check:

python -m py_compile \
  vllm/envs.py \
  vllm/utils/deep_gemm.py \
  vllm/tokenizers/deepseek_v4_encoding.py \
  vllm/v1/attention/backends/mla/sparse_mla_kernels.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  vllm/v1/attention/backends/mla/sparse_swa.py

Targeted tests:

python -m pytest -q \
  tests/tokenizers_/test_deepseek_v4.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_sparse_mla_backends.py \
  tests/v1/attention/test_sm120_deepgemm_fallbacks.py \
  tests/v1/attention/test_sparse_attn_indexer.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py

Result:

151 passed, 504 skipped, 16 warnings in 356.93s

Diff hygiene:

git diff --check origin/main...HEAD

Result: clean.

Preview branch focused checks:

python -m ruff check \
  vllm/v1/attention/backends/mla/sparse_mla_env.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  tests/v1/spec_decode/test_mtp.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py

python -m pytest -q \
  tests/v1/spec_decode/test_mtp.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py

Result:

95 passed, 16 warnings in 48.35s

Serving command

Formal PR branch, no MTP:

PYTHONPATH=~/tmp/vllm-bench-ds4-sm120 \
~/tmp/vllm/.venv/bin/vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port 8017 \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --max-model-len 16384 \
  --gpu-memory-utilization 0.94 \
  --tensor-parallel-size 2 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4

Preview branch, MTP:

PYTHONPATH=~/tmp/vllm-bench-ds4-sm120-full \
~/tmp/vllm/.venv/bin/vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port 8018 \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --max-model-len 16384 \
  --gpu-memory-utilization 0.985 \
  --tensor-parallel-size 2 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4 \
  --speculative-config '{"method":"mtp","num_speculative_tokens":2}'
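
For a quick functional check against either server, the OpenAI-compatible endpoint can be exercised directly. A minimal sketch using the openai Python client (port and model name taken from the formal-branch command above):

```python
from openai import OpenAI

# vLLM serves an OpenAI-compatible API; the key is unused but required.
client = OpenAI(base_url="http://127.0.0.1:8017/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V4-Flash",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=0,
    max_tokens=64,
)
print(resp.choices[0].message.content)
```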

Benchmark command

The short-context benchmark uses 128 -> 512; the long-context benchmark uses
8192 -> 512. Each row uses 48 prompts and temperature=0.

~/tmp/vllm/.venv/bin/vllm bench serve \
  --model deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port <port> \
  --dataset-name random \
  --random-input-len <128-or-8192> \
  --random-output-len 512 \
  --num-prompts 48 \
  --max-concurrency <C> \
  --ignore-eos \
  --temperature 0 \
  --save-result \
  --result-dir <result-dir> \
  --result-filename <name>.json

Formal PR branch benchmark

Branch:

jasl/vllm@ds4-sm120
HEAD: 7a34ed538

Server memory setting:

--gpu-memory-utilization 0.94

MTP is not included in this branch. Starting the formal branch with
--speculative-config '{"method":"mtp","num_speculative_tokens":2}' fails
because the MTP fix stack is intentionally kept separate.

| Context | Concurrency | Output tok/s | Requests/s | Mean TPOT | Mean TTFT |
| --- | --- | --- | --- | --- | --- |
| 128 -> 512 | 1 | 100.38 | 0.196 | 9.76 ms | 113.4 ms |
| 128 -> 512 | 4 | 296.84 | 0.580 | 13.16 ms | 171.9 ms |
| 128 -> 512 | 8 | 478.34 | 0.934 | 16.18 ms | 291.6 ms |
| 8192 -> 512 | 1 | 58.61 | 0.114 | 10.94 ms | 3143.0 ms |
| 8192 -> 512 | 2 | 81.35 | 0.159 | 15.37 ms | 4732.0 ms |
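
As a sanity check on the numbers, per-stream decode rate is roughly the reciprocal of mean TPOT; at concurrency 1 this lines up with the reported throughput (back-of-the-envelope only, not part of the harness):

```python
# 128 -> 512, concurrency 1: mean TPOT 9.76 ms per output token.
per_stream_rate = 1000 / 9.76  # ~102.5 tok/s for a single stream
# Reported output throughput is 100.38 tok/s; the small gap is the first
# token (TTFT ~113 ms) amortized over the 512-token output.
print(round(per_stream_rate, 1))
```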

Result directory:

/home/jasl/tmp/ds4_sm120_bench_20260429_032651

Preview branch benchmark

Branch:

jasl/vllm@ds4-sm120-full
HEAD: ab7336f21

Server memory setting:

--gpu-memory-utilization 0.985

This branch includes the separate MTP fixes and community performance patches.
It is for evaluation only, not the formal PR review target.

Startup notes:

  • no-MTP CUDA graph reserve: 3.67 GiB
  • no-MTP available KV cache: 10.6 GiB
  • MTP CUDA graph reserve: 4.38 GiB
  • MTP available KV cache: 6.2 GiB

| Context | Concurrency | no-MTP tok/s | MTP tok/s | MTP delta | no-MTP TPOT | MTP TPOT | no-MTP TTFT | MTP TTFT | MTP acceptance |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 128 -> 512 | 1 | 103.03 | 161.14 | +56.4% | 9.60 ms | 5.95 ms | 62.3 ms | 138.7 ms | 78.61% |
| 128 -> 512 | 4 | 303.20 | 326.51 | +7.7% | 12.93 ms | 11.47 ms | 145.6 ms | 346.0 ms | 80.14% |
| 128 -> 512 | 8 | 473.53 | 525.08 | +10.9% | 16.46 ms | 14.07 ms | 236.3 ms | 402.2 ms | 77.17% |
| 8192 -> 512 | 1 | 58.54 | 79.17 | +35.2% | 10.81 ms | 6.23 ms | 3223.4 ms | 3283.6 ms | 81.48% |
| 8192 -> 512 | 2 | 80.77 | 98.33 | +21.7% | 15.33 ms | 13.46 ms | 4843.8 ms | 3486.3 ms | 79.02% |
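
For context on the MTP columns: with num_speculative_tokens=2 and a per-position acceptance rate a, the usual speculative-decoding accounting (assuming independent acceptance per position) gives roughly 1 + a + a^2 emitted tokens per verification step. With the measured acceptance at concurrency 1, this is an idealized ceiling rather than a prediction of the observed +56.4%, since each MTP step also costs more than a plain decode step:

```python
a = 0.7861                       # measured MTP acceptance, 128 -> 512, concurrency 1
tokens_per_step = 1 + a + a**2   # ~2.40 tokens emitted per verification step
print(round(tokens_per_step, 2))
```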

Result directory:

/home/jasl/tmp/ds4_sm120_full_bench_20260429_041151

Review notes

Changes made before this update:

  • Removed the temporary DeepGEMM fork pin and related env bridge.
  • Removed sparse MLA diagnostic dump hooks and tests.
  • Kept runtime-facing names production-oriented; test oracle helpers remain
    clearly separated from serving kernels.
  • Verified there are no stale prototype DeepGEMM refs.
  • Re-signed the branch with DCO trailers.
  • Re-ran targeted tests and benchmarks after the cleanup.

Known follow-ups

  • MTP speculative decoding should be reviewed as an independent PR.
  • ds4-sm120-full can continue to carry community performance patches for
    public evaluation.
  • Further SM12x optimization should focus on full decode profiling across
    indexer, MoE, collectives, sampling, and sparse MLA rather than broadening
    this PR.

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added ci/build deepseek Related to DeepSeek models nvidia v1 labels Apr 27, 2026
@jasl
Contributor Author

jasl commented Apr 27, 2026

@WoosukKwon
I rebased my original PR #40899.
Here it is; please help review.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for DeepSeek V4 models, including updates to DeepGEMM integration, new FP8 einsum kernels for SM12x, and infrastructure for sparse MLA attention. However, there are two critical issues: the removal of the optional dependency check for tilelang in vllm/model_executor/layers/mhc.py will break installations on non-CUDA platforms, and the replacement of DeepseekV4MLP with DeepseekV2MLP for shared experts removes necessary swiglu_limit clamping, which is vital for numerical stability in FP8 inference.

Comment thread vllm/model_executor/layers/mhc.py Outdated
Comment thread vllm/model_executor/models/deepseek_v4.py Outdated
@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4e2adf8a9f

Comment thread vllm/model_executor/models/deepseek_v4.py Outdated
Comment thread vllm/model_executor/layers/mhc.py Outdated
@jasl
Contributor Author

jasl commented Apr 27, 2026

The PR is ready for review.
I'm benchmarking the latest results.

@wuwenthink

Thanks to your development, the context length currently supported locally has increased significantly, and decode speed has improved a lot. It's amazing!

@bbbearxyz

@jasl My understanding is that your current approach supports SM120 through a combination of DeepGEMM and Triton. I wonder whether a pure Triton implementation, without depending on DeepGEMM at all, would be cleaner and perhaps worth considering as an alternative. I’d be interested to hear your thoughts.

@jasl
Contributor Author

jasl commented Apr 27, 2026

> @jasl My understanding is that your current approach supports SM120 through a combination of DeepGEMM and Triton. I wonder whether a pure Triton implementation, without depending on DeepGEMM at all, would be cleaner and perhaps worth considering as an alternative. I’d be interested to hear your thoughts.

I don't have a preference.
IMO, contributing to DeepGEMM would help align with DeepSeek's official behavior and encourage them to pay attention to the community's needs.
I can imagine a pure Triton implementation would help improve performance; I can try it later.

jasl and others added 8 commits May 6, 2026 00:07
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Protect hybrid-aligned DeepSeek V4 MLA prompt cache blocks so they survive decode and unrelated long-session cache churn. Keep common-prefix accounting aware of the extra protection reference and cover compressor-state SlidingWindowMLA groups in a regression test.

Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
@v1b3coder

@jasl The engine crashed with the following errors while running the latest ds4-sm120-full after a few hours of random usage. Unfortunately, I cannot replicate it. Btw, would you mind enabling Issues on your repository so we can report in a more structured way, please?

(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138] Traceback (most recent call last):
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/engine/core.py", line 1129, in run_engine_core
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     engine_core.run_busy_loop()
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/engine/core.py", line 1170, in run_busy_loop
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     self._process_engine_step()
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/engine/core.py", line 1209, in _process_engine_step
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     outputs, model_executed = self.step_fn()
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]                               ^^^^^^^^^^^^^^
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/engine/core.py", line 471, in step_with_batch_queue
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     scheduler_output = self.scheduler.schedule()
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]                        ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/core/sched/scheduler.py", line 744, in schedule
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     new_blocks = self.kv_cache_manager.allocate_slots(
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/core/kv_cache_manager.py", line 393, in allocate_slots
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     new_blocks = self.coordinator.allocate_new_blocks(
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/core/kv_cache_coordinator.py", line 187, in allocate_new_blocks
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     return tuple(
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]            ^^^^^^
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/core/kv_cache_coordinator.py", line 188, in <genexpr>
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     manager.allocate_new_blocks(
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/core/single_type_kv_cache_manager.py", line 270, in allocate_new_blocks
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]   File "/opt/jasl/vllm/vllm/v1/core/block_pool.py", line 334, in get_new_blocks
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138]     raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")
(EngineCore pid=253) ERROR 05-05 17:47:29 [core.py:1138] ValueError: Cannot get 2048 free blocks from the pool
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148] Traceback (most recent call last):
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]   File "/opt/jasl/vllm/vllm/entrypoints/openai/chat_completion/serving.py", line 524, in chat_completion_stream_generator
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]     async for res in result_generator:
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]   File "/opt/jasl/vllm/vllm/v1/engine/async_llm.py", line 579, in generate
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]     out = q.get_nowait() or await q.get()
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]                             ^^^^^^^^^^^^^
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]   File "/opt/jasl/vllm/vllm/v1/engine/output_processor.py", line 85, in get
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]     raise output
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]   File "/opt/jasl/vllm/vllm/v1/engine/async_llm.py", line 660, in output_handler
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]     outputs = await engine_core.get_output_async()
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]   File "/opt/jasl/vllm/vllm/v1/engine/core_client.py", line 998, in get_output_async
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148]     raise self._format_exception(outputs) from None
(APIServer pid=1) ERROR 05-05 17:47:29 [serving.py:1148] vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
(APIServer pid=1) INFO:     Shutting down
(APIServer pid=1) INFO:     Waiting for application shutdown.
(APIServer pid=1) INFO:     Application shutdown complete.
(APIServer pid=1) INFO:     Finished server process [1]
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Signed-off-by: jasl <jasl9187@hotmail.com>
doctorpangloss added a commit to doctorpangloss/forks-vllm-ampere that referenced this pull request May 5, 2026
Verified on 2x A5000 sm_86: MLA attention + DeepSeekMoE + bf16 produces
'The capital of France is Paris. The official language is French...' at
PP=2. Triton sparse-MLA from PR vllm-project#40991 + sm_8x gate works on Ampere.
Confirms rung 2 of the model ladder for V4-Flash on Ampere.
@jasl
Contributor Author

jasl commented May 5, 2026

> @jasl The engine crashed with the following errors while running the latest ds4-sm120-full after a few hours of random usage. [full traceback quoted in the comment above]

I'm doing a long-run smoke test; not sure I can trigger it.

@jasl
Contributor Author

jasl commented May 5, 2026

@v1b3coder
You can submit it to my fork repo.

@v1b3coder

I don't see Issues available at https://github.com/jasl/vllm. Maybe you have another repo in mind?

@jasl
Contributor Author

jasl commented May 5, 2026

> I don't see Issues available at https://github.com/jasl/vllm. Maybe you have another repo in mind?

I needed to enable the Issues feature myself; you should see it now.

Route SM12x sparse MLA decode metadata around DeepGEMM scheduler metadata instead of returning placeholder metadata. Let get_paged_mqa_logits_metadata call the backend normally so unexpected SM12x metadata calls fail through the backend.

Also keep SM12x FP8 MQA and paged MQA local fallback dispatch from initializing DeepGEMM before the SM12x guard runs.

Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
@v1b3coder

Thank you, much appreciated. It crashed for a second time now; I opened an issue.

doctorpangloss added a commit to doctorpangloss/forks-vllm-ampere that referenced this pull request May 6, 2026
Five additional gates extended from SM12x-only to also include sm_8x
(Ampere/Ada). PR vllm-project#40991 author noted 'SM80/86/89 architectures could
theoretically use identical Triton approaches'; this validates that
claim across the V4-Flash forward path:

  - vllm/utils/deep_gemm.py:
    * fp8_mqa_logits dispatch (line ~626)
    * fp8_paged_mqa_logits dispatch (line ~924)
    * tf32_hc_prenorm_gemm dispatch (line ~1004)
    All three Triton kernels live in deepseek_v4_triton_kernels.py and
    are sm_80+ portable. DeepGEMM fallback is Hopper/Blackwell-only and
    raises _missing() on Ampere — without these gates V4 attention dies
    on the first token.

  - vllm/model_executor/layers/deepseek_v4_attention.py:
    _use_deepseek_v4_sm12x_triton_fp8_einsum widened to capability.major
    in (8, 12). DeepGEMM equivalent isn't available on Ampere.

  - vllm/model_executor/layers/mhc.py:
    Hyperconnections (mhc_pre/post/hc_head) used TileLang JIT on CUDA.
    TileLang requires sm_89+ — fails JIT compile on sm_80/sm_86. New
    helper _should_use_mhc_torch_fallback() routes torch reference impl
    on Ampere and ROCm. Numerically equivalent, ~1.5-2x slower.
@ehfd
Contributor

ehfd commented May 6, 2026

jasl#3

@pasta-paul

Independent validation on dual DGX Spark GB10 (SM 12.1a, 121 GiB UMA each), TP=2 over QSFP RDMA, against PR head 0789bc9 plus kylesayrs/deepseek-ct cherry-pick + a one-line packed_modules_mapping patch on DeepseekV4ForCausalLM.

Quant: pastapaul/DeepSeek-V4-Flash-W4A16-FP8 (compressed-tensors W4A16 routed-experts + FP8_BLOCK attention).

Validated graphs-ON (no --enforce-eager) at every context size we tested — boot is essentially flat across context sizes:

| Config | Boot | Smoke decode | NIAH retrieval (4 positions) |
| --- | --- | --- | --- |
| 16K × 4-seqs | 338 s | 11.29 t/s | not run |
| 128K × 1-seq | 308 s | 12.07 t/s | 4 / 4 at 100K-tok haystack |
| 256K × 1-seq | 306 s | 9.44 t/s | 4 / 4 at 200K |
| 256K × 2-seqs | 307 s | 8.92 t/s | 4 / 4 at 200K |
| 500K × 1-seq | 306 s | 10.12 t/s | smoke + boot confirmed |
Mini-suite at 256K × 2 graphs-ON: 10 / 10 PASS (smoke 4/4 incl. tool-calling, generation 3 prompts × non-thinking + think-high).

Two observations:

Findings doc + raw evidence: pasta-paul/dsv4-flash-w4a16-fp8 findings/spark_tp2_deployment.md (Phase 4d).

Thank you for the SM12x work — this is the first config that actually serves long-context DSV4-Flash on consumer Blackwell.

Member

@zyongye zyongye left a comment

First of all, thank you for your contribution.

Overall, I think this PR is not ready to merge yet. There are too many changes unrelated to the model enablement, and the tests are not well structured and only test meaningless things. I suggest we do some cleanup before the next round of review.

Regarding changes to the core files, the only change we expect is adding new kernels and branching out in the necessary places. However, all kernel implementations should live in separate files so that we can keep the core files clean (e.g. deepseek_v4_attention.py, sparse_attn_indexer.py). A similar suggestion was also made for the AMD enablement.

I do suggest taking a look at the AMD enablement to see whether the kernels overlap with what is already there.

get_paged_mqa_logits_metadata,
)
from vllm.utils.deep_gemm import (
fp8_fp4_paged_mqa_logits as fp8_paged_mqa_logits,
Member

Why do we want to change this? I'd prefer to keep the original name.

)


def _make_mega_moe_config(
Member

If we are solely testing backend selection, then we don't need to include this in the tests.

Member

I don't think this is needed either.

Member

Also no need to include this

Member

also no need to include.

Member

I suggest adding another branch in forward_cuda and creating another function just for that. Look at forward_hip for reference.

from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
get_pp_group,
Member

PP support should come in a different PR.

Comment thread vllm/utils/deep_gemm.py
_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192


def _fp8_mqa_logits_head_chunk_size(
Member

All non-DeepGEMM library operations should move to a different file. This file is solely for interfacing with DeepGEMM.

and is_triton_sparse_mla_enabled_for_platform()
and not triton_sparse_mla_cudagraphs_allowed(vllm_config)
):
return AttentionCGSupport.NEVER
Member

CUDA graph support should come with this PR.

model_version="deepseek_v4",
)

def _forward_sparse_mla_swa_decode_triton(
Member

Separate file

@jasl
Contributor Author

jasl commented May 6, 2026

I'm closing this one in favor of #41834.
Thank you all for helping me get it this far. The old branch still remains; please submit any issues to my fork.

@jasl jasl closed this May 6, 2026
@github-project-automation github-project-automation Bot moved this to Done in NVIDIA May 6, 2026